# Sample axes: x and y with n0 points each and theta with n1 points, all on [-M, N].
x_values = torch.linspace(-M, N, n0).to(device)
y_values = torch.linspace(-M, N, n0).to(device)
theta_values = torch.linspace(-M, N, n1).to(device)

# Every (x, y, theta) combination as coordinates stacked on the last axis:
# shape (n0, n0, n1, 3), with x varying along dim 0, y along dim 1 and theta along dim 2.
tensor = torch.stack(
    torch.meshgrid(x_values, y_values, theta_values, indexing='ij'),
    dim=-1,
).to(device)

# For each (x, y), pick the theta sample where |f2_plus - f2_minus| is smallest.
# NOTE(review): assumes f2_minus / f2_plus return one value per grid point,
# i.e. shape (n0, n0, n1), so that the reduction over dim -1 runs over theta.
f2_minus_values = f2_minus(tensor, M, N).to(device)
f2_plus_values = f2_plus(tensor, M, N).to(device)
f2_theta = (f2_plus_values - f2_minus_values).abs().to(device)
min_f2_theta_values, min_theta_indices = f2_theta.min(dim=-1)
min_theta_values = theta_values[min_theta_indices].to(device)

# Per-cell (x, y, argmin theta) triples, shape (n0, n0, 3).
x_grid, y_grid = torch.meshgrid(x_values, y_values, indexing='ij')
combined_tensor0 = torch.stack((x_grid, y_grid, min_theta_values), dim=-1)

# For every (x, y) cell, build two sweeps of numpoints_f2_minus_z sample points:
#   sweep 1 runs linearly from -M up to the cell's argmin theta,
#   sweep 2 runs linearly from that theta up to N.
# (Presumably these are quadrature nodes consumed by f2_minus_z / f2_plus_z,
# which also receive the step sizes below — confirm.)
num_x, num_y, _ = combined_tensor0.shape

# Per-cell argmin theta with a trailing singleton axis: (num_x, num_y, 1).
# NOTE(review): this intentionally rebinds the earlier 1-D `theta_values` axis.
theta_values = combined_tensor0[..., 2].unsqueeze(-1)
expanded_theta = theta_values.expand(-1, -1, numpoints_f2_minus_z)  # (num_x, num_y, P)

u_values_1 = torch.linspace(-M, 1, numpoints_f2_minus_z).to(device)
u_values_2 = torch.linspace(-M, N, numpoints_f2_minus_z).to(device)

# Affine map of u_values_1 from [-M, 1] onto [-M, theta].
u_values_tensor_1 = -M + (u_values_1 + M) * (expanded_theta + M) / (1 + M)
# Affine map of u_values_2 from [-M, N] onto [theta, N] (broadcasts to (num_x, num_y, P)).
u_values_tensor_2 = theta_values + (u_values_2 + M) * (N - theta_values) / (N + M)

# Uniform spacing of each sweep, per cell.
step_size1 = (expanded_theta + M) / (numpoints_f2_minus_z - 1)
step_size2 = (N - expanded_theta) / (numpoints_f2_minus_z - 1)

# Broadcast each cell's x and y along the sweep axis, then pack (x, y, theta, u).
expanded_x, expanded_y = (
    combined_tensor0[..., axis].unsqueeze(-1).expand(-1, -1, numpoints_f2_minus_z)
    for axis in (0, 1)
)
combined_tensor2 = torch.stack((expanded_x, expanded_y, expanded_theta, u_values_tensor_1), dim=-1)
combined_tensor4 = torch.stack((expanded_x, expanded_y, expanded_theta, u_values_tensor_2), dim=-1)


# Extend every (x, y, argmin theta) cell along n2 z samples on [-M, N]:
# combined_tensor1 has shape (n0, n0, n2, 4) holding (x, y, theta, z).
x_grid_expanded = x_grid[..., None].expand(n0, n0, n2)
y_grid_expanded = y_grid[..., None].expand(n0, n0, n2)
min_theta_values_expanded = min_theta_values[..., None].expand(n0, n0, n2)
z_values = torch.linspace(-M, N, n2).to(device)
z_grid = z_values.expand(n0, n0, n2)
combined_tensor1 = torch.stack(
    (x_grid_expanded, y_grid_expanded, min_theta_values_expanded, z_grid),
    dim=-1,
)

# Per-channel views of the packed grid, each (n0, n0, n2).
x_grid_expanded, y_grid_expanded, theta_values_expanded, z_values_expanded = combined_tensor1.unbind(dim=-1)

# Split the grid by z relative to theta; a point that does not belong to a grid
# has its whole 4-vector replaced by inf.
mask_z_leq_theta = z_values_expanded <= theta_values_expanded
grid_1 = combined_tensor1.masked_fill(~mask_z_leq_theta.unsqueeze(-1), float('inf'))  # keeps z <= theta
grid_2 = combined_tensor1.masked_fill(mask_z_leq_theta.unsqueeze(-1), float('inf'))   # keeps z > theta

# Evaluate both branches. grid_1 / grid_2 carry inf at the z points that belong
# to the other branch (z > theta and z <= theta respectively).
f_2_minus = f2_minus_z(combined_tensor2, grid_1, M, N, step_size1)
f_2_plus = f2_plus_z(combined_tensor4, grid_2, M, N, step_size2)

# Start from the minus branch, and wherever its first channel is inf, take the
# entire row from the plus branch instead.
inf_mask = torch.isinf(f_2_minus[:, :, :, 0]).unsqueeze(-1)
f2_z = f_2_minus.clone()
f2_z[inf_mask.expand_as(f2_z)] = f_2_plus[inf_mask.expand_as(f2_z)]

# Checkpoint the merged result to disk.
torch.save(f2_z, 'f2_z.pt')


# Split f2_z into its first five channels, each [n0, n0, n2]:
# x, y, theta, z and the f2 value.
x, y, theta, z, f2_z_values = (f2_z[..., channel] for channel in range(5))

# Pointwise ratio f2 / f1. `result` packs (x, y, z, ratio) on the last axis.
f1_z_values = f1(z, x, y, M, N)
f2_divide_f1 = f2_z_values / f1_z_values
result = torch.stack((x, y, z, f2_divide_f1), dim=-1)

# Sup over z for each (x, y) cell. The x < y test uses the first z sample, and
# cells failing it are pushed to +inf so that they never win the minimum.
sup_z = f2_divide_f1.max(dim=-1).values
x_less_than_y = x[..., 0] < y[..., 0]
filtered_sup_z = sup_z.masked_fill(~x_less_than_y, float('inf'))

# Inf over the remaining cells. knn1 is its reciprocal
# (it is 0 when no cell satisfies x < y).
inf_x_less_than_y = filtered_sup_z.min()
knn1 = 1 / inf_x_less_than_y
print(knn1)

# Persist the scalar knn1 result to a one-cell spreadsheet.
knn1_value = knn1.item()
# xlsxwriter only writes the file when the workbook is closed. The context
# manager guarantees close() runs, even if a write raises.
# 'nan_inf_to_errors' stores inf/NaN (e.g. knn1 when the infimum is 0) as Excel
# error cells instead of raising in write().
with xlsxwriter.Workbook('knn_results.xlsx', {'nan_inf_to_errors': True}) as workbook:
    worksheet = workbook.add_worksheet()
    worksheet.write(0, 0, 'KNN1')
    worksheet.write(1, 0, knn1_value)
